{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training PassGAN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This Jupyter notebook is mainly for debugging; it has the same functionality as \"train.py\", but the parameters are set to small values."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "sys.path.append(os.getcwd())\n",
    "\n",
    "import time\n",
    "import pickle\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "import utils\n",
    "import tflib as lib\n",
    "import tflib.ops.linear\n",
    "import tflib.ops.conv1d\n",
    "import tflib.plot\n",
    "import models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.13.1\n"
     ]
    }
   ],
   "source": [
    "# TensorFlow version\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To simplify the migration between the Jupyter Notebook and .py formats, we created a virtual ArgumentParser class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class to virtualize ArgumentParser\n",
    "class VirtualArgparse:\n",
    "    \n",
    "    # Path to dataset\n",
    "    training_data = \"data/train.txt\"\n",
    "    \n",
    "    # Name of directory to output\n",
    "    output_dir = \"pretrained\"\n",
    "    \n",
    "    save_every = 10   #5000\n",
    "    iters = 100   #200000\n",
    "    batch_size = 64\n",
    "    seq_length = 10\n",
    "    layer_dim = 128\n",
    "    critic_iters = 10\n",
    "    lamb = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the class object itself (not an instance) in place of parsed arguments\n",
    "args = VirtualArgparse"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating Directories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output sub-directories. os.makedirs also creates args.output_dir\n",
    "# itself as a parent, and exist_ok=True makes the cell safe to re-run without\n",
    "# a separate (racy) isdir check.\n",
    "for subdir in ('checkpoints', 'samples'):\n",
    "    os.makedirs(os.path.join(args.output_dir, subdir), exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Importing Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loaded 91 lines in dataset\n"
     ]
    }
   ],
   "source": [
    "# Load the training sequences together with the two lookup tables:\n",
    "# charmap is indexed by character, inv_charmap by integer index.\n",
    "# NOTE(review): presumably sequences are cut/padded to max_length inside utils.load_dataset - confirm.\n",
    "lines, charmap, inv_charmap = utils.load_dataset(\n",
    "    path=args.training_data,\n",
    "    max_length=args.seq_length)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating Dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of unique characters in dataset: 51\n"
     ]
    }
   ],
   "source": [
    "# Persist both lookup tables in the output directory; pickle is used rather\n",
    "# than json to avoid encoding errors.\n",
    "for pickle_name, table in (('charmap.pickle', charmap), ('charmap_inv.pickle', inv_charmap)):\n",
    "    pickle_path = os.path.join(args.output_dir, pickle_name)\n",
    "    with open(pickle_path, 'wb') as f:\n",
    "        pickle.dump(table, f)\n",
    "\n",
    "print(\"Number of unique characters in dataset: {}\".format(len(charmap)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Modeling Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integer-encoded batch of real sequences; supplied through feed_dict during training\n",
    "real_inputs_discrete = tf.placeholder(tf.int32, shape=[args.batch_size, args.seq_length])\n",
    "# One-hot expansion of that batch with depth len(charmap)\n",
    "real_inputs = tf.one_hot(real_inputs_discrete, len(charmap))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(\"Placeholder:0\", shape=(64, 10), dtype=int32)\n",
      "Tensor(\"one_hot:0\", shape=(64, 10, 51), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# Show the static shape and dtype of both input tensors\n",
    "print(real_inputs_discrete)\n",
    "print(real_inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n"
     ]
    }
   ],
   "source": [
    "# Output tensor of the generator network built by models.Generator\n",
    "fake_inputs = models.Generator(args.batch_size, args.seq_length, args.layer_dim, len(charmap))\n",
    "# Index of the largest entry along the last axis of fake_inputs\n",
    "fake_inputs_discrete = tf.argmax(fake_inputs, fake_inputs.get_shape().ndims-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Modeling Discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the real batch and the generated batch with models.Discriminator\n",
    "# NOTE(review): presumably both calls share the same parameters through tflib - confirm.\n",
    "disc_real = models.Discriminator(real_inputs, args.seq_length, args.layer_dim, len(charmap))\n",
    "disc_fake = models.Discriminator(fake_inputs, args.seq_length, args.layer_dim, len(charmap))\n",
    "\n",
    "# Discriminator cost: mean score on generated inputs minus mean score on real inputs\n",
    "disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)\n",
    "# Generator cost: negated mean score on generated inputs\n",
    "gen_cost = -tf.reduce_mean(disc_fake)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n"
     ]
    }
   ],
   "source": [
    "# WGAN lipschitz-penalty\n",
    "# One random mixing coefficient per example; the [batch, 1, 1] shape broadcasts\n",
    "# over the two trailing axes of the input tensors.\n",
    "alpha = tf.random_uniform(\n",
    "    shape=[args.batch_size,1,1],\n",
    "    minval=0.,\n",
    "    maxval=1.\n",
    ")\n",
    "\n",
    "# Random points on the segments between real and generated inputs\n",
    "differences = fake_inputs - real_inputs\n",
    "interpolates = real_inputs + (alpha*differences)\n",
    "# Gradient of the discriminator output with respect to the interpolated inputs\n",
    "gradients = tf.gradients(models.Discriminator(interpolates, args.seq_length, args.layer_dim, len(charmap)), [interpolates])[0]\n",
    "# Per-example L2 norm of that gradient; `axis` is used instead of the\n",
    "# deprecated `reduction_indices` alias.\n",
    "slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1,2]))\n",
    "# Penalize the squared deviation of the norm from 1, weighted by args.lamb\n",
    "gradient_penalty = tf.reduce_mean((slopes-1.)**2)\n",
    "disc_cost += args.lamb * gradient_penalty\n",
    "\n",
    "# Parameter lists looked up by name through tflib\n",
    "gen_params = lib.params_with_name('Generator')\n",
    "disc_params = lib.params_with_name('Discriminator')\n",
    "\n",
    "# Separate Adam optimizers, each restricted to its own parameter list\n",
    "gen_train_op = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(gen_cost, var_list=gen_params)\n",
    "disc_train_op = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(disc_cost, var_list=disc_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset iterator\n",
    "def inf_train_gen():\n",
    "    \"\"\"Yield shuffled, integer-encoded training batches forever.\n",
    "\n",
    "    Every pass shuffles the notebook-level `lines` in place and yields\n",
    "    consecutive slices of `args.batch_size` sequences, encoded through\n",
    "    `charmap` into an int32 array. A trailing partial batch is dropped.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `lines` holds fewer than `args.batch_size` sequences.\n",
    "            Without this check no batch could ever be formed and the\n",
    "            endless loop would spin without yielding.\n",
    "    \"\"\"\n",
    "    if len(lines) < args.batch_size:\n",
    "        raise ValueError(\n",
    "            \"dataset has {} lines but batch_size is {}; no full batch can be formed\".format(\n",
    "                len(lines), args.batch_size))\n",
    "    while True:\n",
    "        np.random.shuffle(lines)\n",
    "        for start in range(0, len(lines)-args.batch_size+1, args.batch_size):\n",
    "            batch = lines[start:start+args.batch_size]\n",
    "            yield np.array(\n",
    "                [[charmap[c] for c in l] for l in batch],\n",
    "                dtype='int32'\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "validation set JSD for n=1: 0.49999999999999994\n",
      "validation set JSD for n=2: 0.49999999999999983\n",
      "validation set JSD for n=3: 0.5000000000000001\n",
      "validation set JSD for n=4: 0.5000000000000002\n"
     ]
    }
   ],
   "source": [
    "# During training we monitor JS divergence between the true & generated ngram\n",
    "# distributions for n=1,2,3,4. To get an idea of the optimal values, we\n",
    "# evaluate these statistics on a held-out set first.\n",
    "# The held-out split is capped at half of the data: a fixed 10*batch_size split\n",
    "# would leave the \"true\" side empty on small datasets and make the printed\n",
    "# values meaningless.\n",
    "holdout_size = min(10*args.batch_size, len(lines)//2)\n",
    "true_char_ngram_lms = [utils.NgramLanguageModel(i+1, lines[holdout_size:], tokenize=False) for i in range(4)]\n",
    "validation_char_ngram_lms = [utils.NgramLanguageModel(i+1, lines[:holdout_size], tokenize=False) for i in range(4)]\n",
    "for i in range(4):\n",
    "    print(\"validation set JSD for n={}: {}\".format(i+1, true_char_ngram_lms[i].js_with(validation_char_ngram_lms[i])))\n",
    "# Rebuild the reference models on the full dataset for use during training\n",
    "true_char_ngram_lms = [utils.NgramLanguageModel(i+1, lines, tokenize=False) for i in range(4)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TensorFlow Session"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting TensorFlow session...\n",
      "Local current time : Mon Mar 25 04:51:55 2019\n",
      "10 / 100 (10.0%)\n",
      "20 / 100 (20.0%)\n",
      "30 / 100 (30.0%)\n",
      "40 / 100 (40.0%)\n",
      "50 / 100 (50.0%)\n",
      "60 / 100 (60.0%)\n",
      "70 / 100 (70.0%)\n",
      "80 / 100 (80.0%)\n",
      "90 / 100 (90.0%)\n",
      "100 / 100 (100.0%)\n",
      "...Training done.\n",
      "Ending TensorFlow session.\n",
      "Local current time : Mon Mar 25 04:52:47 2019\n"
     ]
    }
   ],
   "source": [
    "# Run the training loop: alternate generator and discriminator updates,\n",
    "# periodically dump generated samples and write checkpoints.\n",
    "with tf.Session() as session:\n",
    "\n",
    "    # Time stamp\n",
    "    localtime = time.asctime( time.localtime(time.time()) )\n",
    "    print(\"Starting TensorFlow session...\")\n",
    "    print(\"Local current time :\", localtime)\n",
    "\n",
    "    # Initialize every variable of the graph built in the cells above\n",
    "    session.run(tf.global_variables_initializer())\n",
    "\n",
    "    # One Saver for the whole run: constructing a Saver adds save/restore ops\n",
    "    # to the graph, so building one per checkpoint would grow the graph on\n",
    "    # every save. max_to_keep=None keeps every checkpoint file on disk.\n",
    "    model_saver = tf.train.Saver(max_to_keep=None)\n",
    "\n",
    "    # Loop-invariant setting, presumably the directory tflib.plot writes to\n",
    "    lib.plot.output_dir = args.output_dir\n",
    "\n",
    "    def generate_samples():\n",
    "        \"\"\"Run the generator once and decode its output into tuples of characters.\"\"\"\n",
    "        samples = session.run(fake_inputs)\n",
    "        # Keep the index of the largest entry at every position\n",
    "        samples = np.argmax(samples, axis=2)\n",
    "        return [tuple(inv_charmap[idx] for idx in sample) for sample in samples]\n",
    "\n",
    "    gen = inf_train_gen()\n",
    "\n",
    "    for iteration in range(args.iters + 1):\n",
    "        start_time = time.time()\n",
    "\n",
    "        # Train Generator (skipped on the very first pass)\n",
    "        if iteration > 0:\n",
    "            _ = session.run(gen_train_op)\n",
    "\n",
    "        # Train Discriminator\n",
    "        for i in range(args.critic_iters):\n",
    "            _data = next(gen)\n",
    "            _disc_cost, _ = session.run(\n",
    "                [disc_cost, disc_train_op],\n",
    "                feed_dict={real_inputs_discrete:_data}\n",
    "            )\n",
    "\n",
    "        lib.plot.plot('time', time.time() - start_time)\n",
    "        lib.plot.plot('train disc cost', _disc_cost)\n",
    "\n",
    "        # Write generated samples to a text file every 100 iterations\n",
    "        if iteration % 100 == 0 and iteration > 0:\n",
    "\n",
    "            samples = []\n",
    "            for i in range(10):\n",
    "                samples.extend(generate_samples())\n",
    "\n",
    "            for i in range(4):\n",
    "                lm = utils.NgramLanguageModel(i+1, samples, tokenize=False)\n",
    "                lib.plot.plot('js{}'.format(i+1), lm.js_with(true_char_ngram_lms[i]))\n",
    "\n",
    "            # Format only the file name so braces in output_dir cannot break the path\n",
    "            samples_path = os.path.join(args.output_dir, 'samples', 'samples_{}.txt'.format(iteration))\n",
    "            with open(samples_path, 'w') as f:\n",
    "                for s in samples:\n",
    "                    s = \"\".join(s)\n",
    "                    f.write(s + \"\\n\")\n",
    "\n",
    "        if iteration % args.save_every == 0 and iteration > 0:\n",
    "            checkpoint_path = os.path.join(args.output_dir, 'checkpoints', 'checkpoint_{}.ckpt'.format(iteration))\n",
    "            model_saver.save(session, checkpoint_path)\n",
    "            print(\"{} / {} ({}%)\".format(iteration, args.iters, iteration/args.iters*100.0 ))\n",
    "\n",
    "        if iteration == args.iters:\n",
    "            print(\"...Training done.\")\n",
    "\n",
    "        #if iteration % 100 == 0:\n",
    "            #lib.plot.flush()\n",
    "\n",
    "        #lib.plot.tick()\n",
    "\n",
    "# Time stamp\n",
    "localtime = time.asctime( time.localtime(time.time()) )\n",
    "print(\"Ending TensorFlow session.\")\n",
    "print(\"Local current time :\", localtime)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Environment (conda_tensorflow_p36)",
   "language": "python",
   "name": "conda_tensorflow_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
